import torch

### 这种的理解就是一个数组一列
a = torch.tensor([[1,2],[3,4]], dtype=torch.float32)
a8 = torch.mean(a)
print(a8)

a9 = torch.mean(a,dim=(0,1))
print(a9)


a10 = torch.mean(a,dim=0)
print(a10)

a11 = torch.mean(a,dim=1)
print(a11)